############################################################
# Summary: This script reproduces the validation step "comparison with human-labeled data"

# Output:
# - Table 1:  Top-3 model performances (Macro-F1) for balanced and unbalanced datasets.
# - Figure 3: Macro-F1 performance comparison by model 

# Reading Data

# Common set of columns every result source is reduced to before binding.
column_names <- c(
  "dataset", "avg_f1", "sd_f1", "avg_macro_f1", "sd_macro_f1",
  "strategy", "model", "shot_type", "class", "case_count"
)

# Supervised SVM baseline: strategy/model/shot_type are set to constants here
# and the F1 columns are renamed to the shared snake_case names; only the
# common columns are kept, de-duplicated.
df_svm <- read_csv("data/results/classification_results_svm_final.csv") |>
  mutate(
    strategy = "supervised",
    model = "svm",
    shot_type = "none",
    case_count = 999 # tba (placeholder value)
  ) |>
  rename(all_of(c(
    sd_macro_f1 = "sd_macro_F1",
    avg_macro_f1 = "avg_macro_F1",
    avg_f1 = "avg_F1_score",
    sd_f1 = "sd_F1_score"
  ))) |>
  select(all_of(column_names)) |>
  distinct()

# Semi-supervised results: numeric class codes 0/1/2 are mapped to the labels
# "none"/"emp"/"dur" (any other code becomes NA), F1 columns are renamed to the
# shared names, and only the common columns are kept, de-duplicated.
df_sup <- read_csv("data/results/classification_results_semisupervised_final.csv") |>
  mutate(
    strategy = "semisupervised",
    class = case_when(
      class == 0 ~ "none",
      class == 1 ~ "emp",
      class == 2 ~ "dur"
    ),
    shot_type = "none"
  ) |>
  rename(all_of(c(
    sd_macro_f1 = "sd_macro_F1",
    avg_macro_f1 = "macro_F1",
    avg_f1 = "avg_F1_Score",
    sd_f1 = "sd_F1_Score"
  ))) |>
  # NOTE(review): excludes rows whose avg_f1 equals 0.4667375 (within 1e-6);
  # the reason for dropping this exact value is not visible here -- confirm.
  # Rows with NA avg_f1 are dropped by this filter as well.
  dplyr::filter(abs(avg_f1 - 0.4667375) > 1e-6) |>
  select(all_of(column_names)) |>
  distinct()


# Prompting (LLM) results.
# The GPT-mini runs are read and tagged with model = "gpt-40-mini" (sic), then
# excluded by the filter below; drop that filter to include them. Note the
# filter also removes any row whose model is NA.
# "weak_*" dataset names are renamed to "combined_*" to match the other sources.
df_pro <- bind_rows(
  read_csv("data/results/classification_results_prompting_final.csv"),
  read_csv("data/results/classification_results_prompting_gpt_mini.csv") |>
    mutate(model = "gpt-40-mini")
) |>
  mutate(strategy = "prompting") |>
  filter(model != "gpt-40-mini") |>
  mutate(dataset = case_when(
    dataset == "weak_balanced" ~ "combined_balanced",
    dataset == "weak_unbalanced" ~ "combined_unbalanced",
    TRUE ~ dataset # keep every other dataset name unchanged
  )) |>
  rename(sd_macro_f1 = sd_macro_F1,
         avg_macro_f1 = macro_F1,
         avg_f1 = avg_F1_Score,
         sd_f1 = sd_F1_Score) |>
  distinct(across(all_of(column_names)))


# Stack all result sources, then keep one row per configuration (dataset,
# scores, strategy, model, shot type) -- distinct() drops the per-class columns.
# Each dataset is tagged with its class distribution and its label source;
# names outside the four known datasets get NA for both tags.
df_full <- bind_rows(df_svm, df_sup, df_pro)
df <- df_full |>
  distinct(dataset, avg_macro_f1, sd_macro_f1, strategy, model, shot_type) |>
  mutate(
    dataset_dist = case_when(
      dataset %in% c("combined_balanced", "strong_balanced") ~ "balanced",
      dataset %in% c("combined_unbalanced", "strong_unbalanced") ~ "unbalanced"
    ),
    dataset_type = case_when(
      dataset %in% c("combined_balanced", "combined_unbalanced") ~ "combined",
      dataset %in% c("strong_balanced", "strong_unbalanced") ~ "strong"
    )
  )


### Table 1 ####

# Return the top-`n` configurations by mean Macro-F1 for one class distribution.
# `dist` is compared with `dataset_dist` ("balanced"/"unbalanced" at this point
# in the script). Rows are sorted before the scores are rounded to 2 decimals
# for display. `.env$dist` keeps the argument from being shadowed by a column.
top_macro_f1 <- function(data, dist, n = 3) {
  data |>
    filter(dataset_dist == .env$dist) |>
    arrange(desc(avg_macro_f1)) |>
    mutate(avg_macro_f1 = round(avg_macro_f1, 2),
           sd_macro_f1 = round(sd_macro_f1, 2)) |>
    head(n) |>
    select(model, strategy, dataset, shot_type, avg_macro_f1, sd_macro_f1)
}

# Picking the best three performances for balanced and unbalanced datasets.
# Explicit print() so the tables are also shown when the script is source()d,
# where top-level values are not auto-printed.
table1_unbalanced <- top_macro_f1(df, "unbalanced")
table1_balanced <- top_macro_f1(df, "balanced")
print(table1_unbalanced)
print(table1_balanced)

## Visualisation

# Re-derive the class distribution with capitalised labels for the figure,
# from the dataset-name suffix; names matching neither suffix become NA.
df <- df |>
  mutate(dataset_dist = if_else(
    str_detect(dataset, "_balanced"),
    "Balanced",
    if_else(str_detect(dataset, "_unbalanced"), "Unbalanced", NA_character_)
  ))


unique(df$model)

# Raw model identifiers (as found in the result CSVs) and their display labels
# for Figure 3. factor() turns any identifier not listed here into NA, so warn
# explicitly instead of letting such rows vanish from the plot silently.
# NOTE(review): "deepseak" is presumably the raw (misspelled) label in the
# prompting CSV -- verify against the data.
model_levels <- c("svm", "xlm-roberta-base", "llama", "gpt", "deepseak")
model_labels <- c("SVM", "XLM-RoBERTa", "Llama 3-8B", "GPT-4o", "Deepseek-V3")

unknown_models <- setdiff(unique(df$model), model_levels)
if (length(unknown_models) > 0) {
  warning("Models without a display label (they become NA): ",
          paste(unknown_models, collapse = ", "), call. = FALSE)
}

df <- df |>
  mutate(
    model = factor(model, levels = model_levels, labels = model_labels),
    dataset_type = factor(dataset_type, levels = c("strong", "combined"))
  )

unique(df$model)

# Filter out zero-shot data (rows with NA shot_type are dropped too), then
# relabel the training-data type for the legend via a lookup table:
# "combined" -> "weak+strong". The result is a plain character column; values
# not in the lookup become NA.
df_filtered <- df |> filter(shot_type != "zero")
df_filtered$dataset_type <- unname(
  c(strong = "strong", combined = "weak+strong")[as.character(df_filtered$dataset_type)]
)

##### Figure 3 ####

# Build one Figure 3 panel: mean Macro-F1 (point) +/- one SD (error bar) per
# model, dodged by training-data type and faceted by model and strategy.
# coord_cartesian() zooms to the 0.5-0.8 window without discarding data, so
# points and error bars reaching outside that window are clipped rather than
# removed (ylim() would drop them entirely).
plot_macro_f1 <- function(data, title, show_legend = TRUE) {
  dodge <- position_dodge(width = 0.4)
  ggplot(data, aes(x = model, y = avg_macro_f1,
                   color = dataset_type, shape = dataset_type)) +
    geom_errorbar(aes(group = dataset_type,
                      ymin = avg_macro_f1 - sd_macro_f1,
                      ymax = avg_macro_f1 + sd_macro_f1),
                  color = "black", width = 0.2, position = dodge) +
    geom_point(size = 3, position = dodge) +
    labs(title = title, x = "Model", y = "Macro-F1",
         color = "Dataset", shape = "Dataset") +
    coord_cartesian(ylim = c(0.5, 0.8)) +
    facet_wrap(~ model + strategy, scales = "free_x",
               labeller = label_both, nrow = 1) +
    theme_minimal() +
    theme(
      plot.title = element_text(size = 10, face = "bold"),
      strip.background = element_blank(),
      strip.text.x = element_blank(),
      axis.title.x = element_blank(),
      legend.title = element_text(size = 10, face = "bold"),
      legend.position = if (show_legend) "right" else "none",
      # Same spacing in both panels so the stacked plots line up.
      panel.spacing = unit(0.1, "lines")
    )
}

# Create the balanced dataset plot
df_balanced <- df_filtered |> filter(dataset_dist == "Balanced")
if (nrow(df_balanced) > 0) {
  p_balanced <- plot_macro_f1(df_balanced, "Balanced Test & Train Datasets")
} else {
  print("No balanced dataset found!")
  p_balanced <- ggplot() + labs(title = "No balanced data available")
}

# Create the unbalanced dataset plot (legend hidden; the balanced panel shows it)
df_unbalanced <- df_filtered |> filter(dataset_dist == "Unbalanced")
if (nrow(df_unbalanced) > 0) {
  p_unbalanced <- plot_macro_f1(df_unbalanced, "Unbalanced Test & Train Datasets",
                                show_legend = FALSE)
} else {
  print("No unbalanced dataset found!")
  p_unbalanced <- ggplot() + labs(title = "No unbalanced data available")
}

# Arrange plots in two rows and save that exact object, rather than relying
# on ggsave()'s implicit last_plot() default.
fig3 <- ggarrange(p_balanced, p_unbalanced, ncol = 1, nrow = 2, align = "v")
print(fig3)

dir.create("output", showWarnings = FALSE, recursive = TRUE)
ggsave("output/Figure_3_F1scores.png", plot = fig3,
       width = 7, height = 4, dpi = 900)
